#include <torch/csrc/jit/codegen/cuda/mutator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>

#include <vector>

namespace torch {
namespace jit {
namespace fuser {

// Runs this mutator over every expression currently held by `fusion`.
void OptOutMutator::mutate(Fusion* fusion) {
  // Work on a snapshot of the expression list: mutating an expression may
  // insert nodes, remove nodes, or both, which changes the DAG while we walk
  // it.
  //
  // The expressions are visited in topologically sorted order. The Fusion
  // tracks what has and has not changed through its origin dependency
  // tracking. When a newly added operation writes to a Val that was previously
  // the output of another expression, that older expression is removed,
  // because the SSA restriction only allows a Val to be assigned once. As a
  // result there is no need to track manually which expressions stayed the
  // same and which were replaced.
  const std::vector<Expr*> snapshot = fusion->exprs();

  for (Statement* stmt : snapshot) {
    mutate(stmt);
  }
}

// MUTATE FUNCTIONS FOR VALS

// Mutates the size of `id`. If the mutated size is not sameAs() the original
// size, a new IterDomain is built from the new size plus the original parallel
// method and reduction flag, registered as the mutation of `id`, and returned.
// Otherwise `id` itself is returned.
Statement* OptOutMutator::mutate(IterDomain* id) {
  Val* new_size = mutateAsVal(id->size())->asVal();

  // Nothing changed underneath: keep the existing node.
  if (new_size->sameAs(id->size())) {
    return id;
  }

  Val* replacement =
      new IterDomain(new_size, id->parallel_method(), id->isReduction());
  registerMutation(id, replacement);
  return replacement;
}

// Mutates every axis of `td`. If at least one mutated axis is not sameAs() its
// original, a new TensorDomain is built from the mutated axes, registered as
// the mutation of `td`, and returned. Otherwise `td` itself is returned.
Statement* OptOutMutator::mutate(TensorDomain* td) {
  std::vector<IterDomain*> new_axes;
  bool any_axis_changed = false;

  for (decltype(td->size()) pos = 0; pos < td->size(); ++pos) {
    IterDomain* new_axis =
        static_cast<IterDomain*>(mutateAsVal(td->axis(pos)));
    new_axes.push_back(new_axis);

    const bool same_as_before = new_axis->sameAs(td->axis(pos));
    if (!same_as_before) {
      any_axis_changed = true;
    }
  }

  if (!any_axis_changed) {
    return td;
  }

  Val* replacement = new TensorDomain(new_axes);
  registerMutation(td, replacement);
  return replacement;
}

// Mutates the domain of `tv` and, when `tv` has a compute-at view, that view
// as well. If either differs from the original (per sameAs()), a new TensorView
// is created with the mutated domain and the original data type, the
// compute-at relation is re-established on it at the original axis, the
// mutation is registered, and the new view is returned. Otherwise `tv` itself
// is returned.
Statement* OptOutMutator::mutate(TensorView* tv) {
  TensorDomain* new_domain =
      static_cast<TensorDomain*>(mutateAsVal(tv->domain()));

  TensorView* new_compute_at_view = nullptr;
  if (tv->hasComputeAt()) {
    new_compute_at_view =
        static_cast<TensorView*>(mutateAsVal(tv->getComputeAtView()));
  }

  // The compute-at view is only compared when the domain is unchanged and a
  // compute-at view exists.
  const bool unchanged = tv->domain()->sameAs(new_domain) &&
      (!tv->hasComputeAt() ||
       tv->getComputeAtView()->sameAs(new_compute_at_view));
  if (unchanged) {
    return tv;
  }

  TensorView* replacement =
      new TensorView(new_domain, tv->getDataType().value());
  if (tv->hasComputeAt()) {
    replacement->setComputeAt(
        new_compute_at_view, static_cast<int>(tv->getComputeAtAxis()));
  }
  registerMutation(tv, replacement);
  return replacement;
}

// Mutates every index of `ti`. Each mutated index is asserted to be a Val that
// isAnInt(). If any mutated index is not sameAs() its original, a new
// TensorIndex over the same view with the mutated indices is created,
// registered as the mutation of `ti`, and returned. Otherwise `ti` itself is
// returned.
Statement* OptOutMutator::mutate(TensorIndex* ti) {
  // Phase 1: mutate all indices before validating any of them.
  std::vector<Statement*> inds;
  for (auto* original_index : ti->indices()) {
    inds.push_back(mutateAsVal(original_index));
  }

  // Phase 2: validate each result and detect whether anything changed.
  bool any_index_changed = false;
  for (decltype(inds.size()) i{0}; i < inds.size(); i++) {
    TORCH_INTERNAL_ASSERT(inds[i]->isVal() && inds[i]->asVal()->isAnInt());
    if (!inds[i]->sameAs(ti->index(i))) {
      any_index_changed = true;
    }
  }

  if (any_index_changed) {
    std::vector<Val*> new_indices;
    new_indices.reserve(inds.size());
    for (Statement* mutated_index : inds) {
      new_indices.push_back(mutated_index->asVal());
    }

    Val* replacement = new TensorIndex(ti->view(), new_indices);
    registerMutation(ti, replacement);
    return replacement;
  }

  return ti;
}

// Float values are not rewritten by this mutator: returns `n` unchanged.
Statement* OptOutMutator::mutate(Float* n) {
  return n;
}
// Int values are not rewritten by this mutator: returns `n` unchanged.
Statement* OptOutMutator::mutate(Int* n) {
  return n;
}
// NamedScalar values are not rewritten by this mutator: returns `n` unchanged.
Statement* OptOutMutator::mutate(NamedScalar* n) {
  return n;
}

// MUTATE FUNCTIONS FOR EXPRESSIONS.

// Mutates the buffer and extent of `a`. If both are sameAs() the originals,
// `a` is returned. Otherwise `a` is removed from the current fusion and a new
// Allocate built from the mutated buffer and extent is returned.
Statement* OptOutMutator::mutate(Allocate* a) {
  TensorView* new_buffer =
      static_cast<TensorView*>(mutateAsVal(a->buffer()));
  Val* new_extent = mutateAsVal(a->extent())->asVal();

  const bool unchanged =
      new_extent->sameAs(a->extent()) && new_buffer->sameAs(a->buffer());
  if (!unchanged) {
    // Drop the stale expression before creating its replacement.
    FusionGuard::getCurFusion()->removeExpr(a);
    return new Allocate(new_buffer, new_extent);
  }
  return a;
}

// Mutates the output domain, input domain and factor of `s`. If all three are
// sameAs() the originals, `s` is returned. Otherwise `s` is removed from the
// current fusion and a new Split on the same axis, built from the mutated
// operands, is returned.
Statement* OptOutMutator::mutate(Split* s) {
  TensorDomain* new_out = static_cast<TensorDomain*>(mutateAsVal(s->out()));
  TensorDomain* new_in = static_cast<TensorDomain*>(mutateAsVal(s->in()));
  Int* new_factor = static_cast<Int*>(mutateAsVal(s->factor()));

  const bool unchanged = new_out->sameAs(s->out()) &&
      new_in->sameAs(s->in()) && new_factor->sameAs(s->factor());
  if (!unchanged) {
    // Drop the stale expression before creating its replacement.
    FusionGuard::getCurFusion()->removeExpr(s);
    return new Split(new_out, new_in, s->axis(), new_factor);
  }
  return s;
}

// Mutates the output and input domains of `m`. If both are sameAs() the
// originals, `m` is returned. Otherwise `m` is removed from the current fusion
// and a new Merge on the same axis, built from the mutated domains, is
// returned.
Statement* OptOutMutator::mutate(Merge* m) {
  TensorDomain* new_out = static_cast<TensorDomain*>(mutateAsVal(m->out()));
  TensorDomain* new_in = static_cast<TensorDomain*>(mutateAsVal(m->in()));

  const bool unchanged = new_out->sameAs(m->out()) && new_in->sameAs(m->in());
  if (!unchanged) {
    // Drop the stale expression before creating its replacement.
    FusionGuard::getCurFusion()->removeExpr(m);
    return new Merge(new_out, new_in, m->axis());
  }
  return m;
}

// Mutates the output and input domains of `ro`. If both are sameAs() the
// originals, `ro` is returned. Otherwise `ro` is removed from the current
// fusion and a new Reorder with the same pos2axis mapping, built from the
// mutated domains, is returned.
Statement* OptOutMutator::mutate(Reorder* ro) {
  TensorDomain* new_out = static_cast<TensorDomain*>(mutateAsVal(ro->out()));
  TensorDomain* new_in = static_cast<TensorDomain*>(mutateAsVal(ro->in()));

  const bool unchanged =
      new_out->sameAs(ro->out()) && new_in->sameAs(ro->in());
  if (!unchanged) {
    // Drop the stale expression before creating its replacement.
    FusionGuard::getCurFusion()->removeExpr(ro);
    return new Reorder(new_out, new_in, ro->pos2axis());
  }
  return ro;
}

// Mutates the output and input of `uop`. If both are sameAs() the originals,
// `uop` is returned. Otherwise `uop` is removed from the current fusion and a
// new UnaryOp of the same op type, built from the mutated operands, is
// returned.
Statement* OptOutMutator::mutate(UnaryOp* uop) {
  Val* new_out = mutateAsVal(uop->out())->asVal();
  Val* new_in = mutateAsVal(uop->in())->asVal();

  const bool unchanged =
      new_out->sameAs(uop->out()) && new_in->sameAs(uop->in());
  if (!unchanged) {
    // Drop the stale expression before creating its replacement.
    FusionGuard::getCurFusion()->removeExpr(uop);
    return new UnaryOp(uop->getUnaryOpType(), new_out, new_in);
  }
  return uop;
}

// Mutates the output and both operands of `bop`. If all three mutated pointers
// are identical to the originals, `bop` is returned. Otherwise `bop` is removed
// from the current fusion and a new BinaryOp of the same op type, built from
// the mutated operands, is returned.
//
// NOTE(review): this overload compares by pointer identity, whereas the
// UnaryOp overload compares with sameAs(). Confirm whether that difference is
// intentional.
Statement* OptOutMutator::mutate(BinaryOp* bop) {
  Val* new_out = mutateAsVal(bop->out())->asVal();
  Val* new_lhs = mutateAsVal(bop->lhs())->asVal();
  Val* new_rhs = mutateAsVal(bop->rhs())->asVal();

  const bool unchanged = new_out == bop->out() && new_lhs == bop->lhs() &&
      new_rhs == bop->rhs();
  if (!unchanged) {
    // Drop the stale expression before creating its replacement.
    FusionGuard::getCurFusion()->removeExpr(bop);
    return new BinaryOp(bop->getBinaryOpType(), new_out, new_lhs, new_rhs);
  }
  return bop;
}

// ForLoop nodes are not rewritten by this mutator: returns `n` unchanged.
Statement* OptOutMutator::mutate(ForLoop* n) {
  return n;
}

// IfThenElse nodes are not rewritten by this mutator: returns `n` unchanged.
Statement* OptOutMutator::mutate(IfThenElse* n) {
  return n;
}

// START REPLACE ALL

void ReplaceAll::replaceInpOut() {
  Fusion* fusion = FusionGuard::getCurFusion();
  for (auto it : mutations) {
    Val* val = it.first;
    if (fusion->hasInput(val)) {
      fusion->replaceInput(it.first, it.second);
    } else if (fusion->hasOutput(val)) {
      fusion->replaceOutput(it.first, it.second);
    }
  }
}

// Convenience overload: builds a single-entry replacement map
// {instance -> with} and forwards it to the map overload of instancesOf().
//
// The current fusion is not fetched here; the map overload obtains it itself.
void ReplaceAll::instancesOf(Val* instance, Val* with) {
  std::unordered_map<Val*, Val*> replacement_map;
  replacement_map[instance] = with;
  ReplaceAll::instancesOf(replacement_map);
}

// Builds a ReplaceAll mutator from `replacement_map`, runs it over every
// expression reported by the current fusion's unordered_exprs(), and finally
// calls replaceInpOut() on it.
void ReplaceAll::instancesOf(std::unordered_map<Val*, Val*> replacement_map) {
  Fusion* active_fusion = FusionGuard::getCurFusion();
  ReplaceAll replacer(std::move(replacement_map));

  // Iterate over a private copy: the fusion's expression collection is
  // modified in place while mutating, so it must not be walked directly.
  auto&& live_exprs = active_fusion->unordered_exprs();
  const std::vector<Expr*> snapshot(live_exprs.begin(), live_exprs.end());

  for (Expr* expr : snapshot) {
    replacer.mutate(expr);
  }

  replacer.replaceInpOut();
}

// Runs a ReplaceAll built from (`instance`, `with`) on the single expression
// `within`, under a FusionGuard for that expression's fusion. Does nothing
// when `within` is null.
void ReplaceAll::instancesWithin(Val* instance, Val* with, Expr* within) {
  if (within != nullptr) {
    // Guard is scoped to this block, so it covers the whole mutation.
    FusionGuard guard(within->fusion());
    ReplaceAll replacer(instance, with);
    replacer.mutate(within);
  }
}

// Runs a ReplaceAll built from `replacement_map` on the single expression
// `within`, under a FusionGuard for that expression's fusion. Does nothing
// when `within` is null.
void ReplaceAll::instancesWithin(
    std::unordered_map<Val*, Val*> replacement_map,
    Expr* within) {
  if (within != nullptr) {
    // Guard is scoped to this block, so it covers the whole mutation.
    FusionGuard guard(within->fusion());
    ReplaceAll replacer(std::move(replacement_map));
    replacer.mutate(within);
  }
}

} // namespace fuser
} // namespace jit
} // namespace torch
